//
//  crop.metal
//  EffectMgrMetal
//
//  Created by WS on 2021/5/26.
//  Copyright © 2021 WS. All rights reserved.
//

#include <metal_stdlib>
using namespace metal;

// Clamps the lvalue `v` in place to the inclusive range [min, max].
//
// - Every argument is parenthesized so expressions such as `w - 1` or a
//   ternary can be passed without precedence surprises.
// - The expansion is wrapped in braces so that it forms a single compound
//   statement; a following `else` can no longer bind to the macro's own
//   `if / else if` chain.
// - The macro works both with and without a trailing semicolon at the call
//   site (both styles are used in this file).
// - `v`, `min` and `max` may be evaluated more than once: do not pass
//   expressions with side effects.
// - Comparisons happen in the arguments' own types; with an unsigned `v` the
//   lower bound check against 0 can never fire.
#define CLAMP(v, min, max) \
    { \
        if ((v) < (min)) { \
            (v) = (min); \
        } else if ((v) > (max)) { \
            (v) = (max); \
        } \
    }

// Reads one texel from `in`, clamping the coordinate to the texture rectangle
// [0, inW - 1] x [0, inH - 1].
//
// in   : source texture (read access).
// x, y : texel coordinate. The parameters are unsigned, but the caller in this
//        file passes signed ints, so a negative index arrives here wrapped to
//        a very large value. The coordinates are therefore converted back to
//        signed before clamping, so that negative indices clamp to 0 instead
//        of to the far edge.
// inW, inH : texture width / height in texels used as clamp bounds.
// Returns the texel at the clamped coordinate.
static float4 GetPixelClamped(texture2d<float, access::read> in [[texture(0)]], uint x, uint y, int inW, int inH) {
    int px = int(x);
    int py = int(y);
    CLAMP(px, 0, inW - 1);
    CLAMP(py, 0, inH - 1);
    return in.read(uint2(uint(px), uint(py)));
}

// Linear interpolation between A and B.
// t = 0 yields A, t = 1 yields B; t is not clamped.
static float Lerp (float A, float B, float t) {
    const float weightA = 1.0f - t;
    const float weightB = t;
    return A * weightA + B * weightB;
}

// Bilinearly samples `in` at the normalized coordinate (u, v).
//
// in       : source texture (read access).
// u, v     : normalized coordinate; multiplied by inW / inH to get texel space.
// inW, inH : texture width / height in texels.
// Returns the interpolated value of the four neighbouring texels, with
// neighbours outside the texture replaced by the nearest edge texel.
static float4 SampleBilinear (texture2d<float, access::read> in [[texture(0)]],
                       float u, float v, int inW, int inH) {
    // Convert to texel space; the half-texel offset keeps the image from
    // shifting down and left by half a pixel.
    float x = u * float(inW) - 0.5f;
    float xbase = floor(x);
    float xfract = x - xbase;

    float y = v * float(inH) - 0.5f;
    float ybase = floor(y);
    float yfract = y - ybase;

    // The integer index must come from the same floor() as the fraction.
    // Truncating toward zero (int(x)) disagrees with floor() for x in (-1, 0)
    // and would blend the edge texel with its inner neighbour.
    int x0 = int(xbase);
    int x1 = x0 + 1;
    int y0 = int(ybase);
    int y1 = y0 + 1;

    // Clamp while the indices are still signed: GetPixelClamped takes
    // unsigned coordinates, so a negative index must not reach it.
    CLAMP(x0, 0, inW - 1);
    CLAMP(x1, 0, inW - 1);
    CLAMP(y0, 0, inH - 1);
    CLAMP(y1, 0, inH - 1);

    // Fetch the 2x2 neighbourhood.
    auto p00 = GetPixelClamped(in, x0, y0, inW, inH);
    auto p10 = GetPixelClamped(in, x1, y0, inW, inH);
    auto p01 = GetPixelClamped(in, x0, y1, inW, inH);
    auto p11 = GetPixelClamped(in, x1, y1, inW, inH);

    // Interpolate each channel: first along x on both rows, then along y.
    float4 ret;
    for (int i = 0; i < 4; ++i)
    {
        float row0 = Lerp(p00[i], p10[i], xfract);
        float row1 = Lerp(p01[i], p11[i], xfract);
        float value = Lerp(row0, row1, yfract);
        // Result is limited to [0, 255] per channel.
        CLAMP(value, 0.0f, 255.0f);
        ret[i] = value;
    }
    return ret;
}
// Color dodge: bgCol / (1 - overlay), computed per component.
// For the x, y and z channels the result is forced to 1.0 wherever the
// overlay channel exceeds 0.99999 (the denominator is then ~0).
// The w channel is never patched.
float4 colorDodge(float4 bgCol, float4 overlay)
{
	float4 dodged = bgCol / (1.0 - overlay);

	for (int ch = 0; ch < 3; ++ch)
	{
		if (overlay[ch] > 0.99999)
			dodged[ch] = 1.0;
	}

	return dodged;
}

// Color burn: 1 - (1 - bgCol) / overlay, computed per component.
// For the x, y and z channels the result is forced to 0.0 wherever the
// overlay channel is below 0.000001 (the divisor is then ~0).
// The w channel is never patched.
float4 colorBurn(float4 bgCol, float4 overlay)
{
	float4 burned = 1.0 - (1.0 - bgCol) / overlay;

	for (int ch = 0; ch < 3; ++ch)
	{
		if (overlay[ch] < 0.000001)
			burned[ch] = 0.0;
	}

	return burned;
}

// Color-dodge variant used by the hard-mix blend mode:
// bgCol / (1 - overlay), computed per component.
// Unlike colorDodge, the guard looks at the background: the x, y and z
// channels are forced to 0.0 wherever the background channel is below
// 0.000001. The w channel is never patched.
float4 colorDodgeForHardMix(float4 bgCol, float4 overlay)
{
	float4 dodged = bgCol / (1.0 - overlay);

	for (int ch = 0; ch < 3; ++ch)
	{
		if (bgCol[ch] < 0.000001)
			dodged[ch] = 0.0;
	}

	return dodged;
}

// Color-burn variant used by the hard-mix blend mode:
// 1 - (1 - bgCol) / overlay, computed per component.
// Unlike colorBurn, the guard looks at the background: the x, y and z
// channels are forced to 1.0 wherever the background channel exceeds
// 0.9999999. The w channel is never patched.
float4 colorBurnForHardMix(float4 bgCol, float4 overlay)
{
	float4 burned = 1.0 - (1.0 - bgCol) / overlay;

	for (int ch = 0; ch < 3; ++ch)
	{
		if (bgCol[ch] > 0.9999999)
			burned[ch] = 1.0;
	}

	return burned;
}

// tempMatt: matt without alpha.
// matt: matt with alpha.

// Blends the overlay color `ovl` onto `backGround` using the mode selected by
// `blendingMode` and returns the resulting color.
//
// backGround     : background color (copied to bgCol; premultiplied by its
//                  own alpha only in the default branch).
// ovl            : overlay color; it is multiplied by tempMatt before use.
// matt           : matte factor; opacity * matt * exeMatt is the blend weight
//                  (tempOpacity) applied to the blended color.
// tempMatt       : matte factor that scales `ovl`; also used in the color
//                  weight when ovlAlphaPreMul != 0 (mode 0 and default).
// exeMatt        : additional matte factor multiplied into the blend weight.
// blendingMode   : selects the switch case below (0..17 and 19..23 are
//                  handled explicitly; any other value, including 18, takes
//                  the default branch).
// opacity        : overall overlay opacity.
// ovlAlphaPreMul : 0 or non-zero; only read in mode 0 and in the default
//                  branch, where it selects which weight multiplies the
//                  overlay color.
//                  NOTE(review): presumably flags an overlay whose color is
//                  already premultiplied by alpha - confirm with the caller.
//
// In the formulas quoted in the case comments, "Target" is bgCol and "Blend"
// is overlay.
//
// Modes that `break` out of the switch share the tail at the end of the
// function: clamp to [0, 1], alpha = overlay.w + (1 - overlay.w) * bgCol.w,
// then mix the blended color with the background using tempOpacity and
// invTemOpacity. Modes 0, 22 and the default branch return directly.
float4 blending(float4 backGround, float4 ovl, float matt, float tempMatt, float exeMatt, int blendingMode, float opacity, int ovlAlphaPreMul)
{
	// a / b are per-channel 0/1 masks used by the threshold-based modes (8..14).
	float3 a = float3(0.0, 0.0, 0.0);
	float3 b = float3(0.0, 0.0, 0.0);
	float4 outputColor = float4(0.0, 0.0, 0.0, 0.0);
	float4 overlay = ovl * tempMatt;
	float4 bgCol = backGround;
	// Weight of the blended result and its complement for the background.
	float tempOpacity = opacity * matt * exeMatt;
	float invTemOpacity = 1.0 - tempOpacity;
	switch (blendingMode)
	{
	case 0:// normal: plain "over" mix; returns directly, without the shared tail
	//bgCol = float4(bgCol.xyz*bgCol.w, bgCol.w);

		outputColor = overlay;
		if (ovlAlphaPreMul == 0)
		{
			outputColor.w = tempOpacity + invTemOpacity * bgCol.w;
			outputColor.xyz = outputColor.xyz*tempOpacity + invTemOpacity * bgCol.xyz;
			return outputColor;
		}
		else {
			// Same alpha as above, but the overlay color is weighted by
			// opacity * tempMatt instead of tempOpacity.
			outputColor.w = tempOpacity + invTemOpacity * bgCol.w;
			float fOpacity = opacity * tempMatt;
			outputColor.xyz = outputColor.xyz*fOpacity + invTemOpacity * bgCol.xyz;
			return outputColor;
		}
	case 1: // Darken: min(Target, Blend)
		outputColor = min(overlay, bgCol);
		break;
	case 2: // multiply: Target * Blend
		outputColor = bgCol * overlay;
		break;
	case 3: //  color burn // 1 - (1-Target) / Blend
	{
		float4 temp = (1.0 - bgCol) / overlay;
		// Where the background channel is ~1 the quotient is forced to 0,
		// so the result for that channel is 1.
		if (bgCol.x > 0.99999)
			temp.x = 0.0;
		if (bgCol.y > 0.99999)
			temp.y = 0.0;
		if (bgCol.z > 0.99999)
			temp.z = 0.0;
		outputColor = 1.0 - temp;
	}
	break;
	case 4: // Linear burn: Target + Blend - 1
		outputColor = overlay + bgCol - 1.0;
		break;
	case 5: // screen: 1 - (1-Target) * (1-Blend)
		outputColor = 1.0 - (1.0 - bgCol)*(1.0 - overlay);
		break;
	case 6: // color dodge: Target / (1-Blend)
	{
		outputColor = bgCol / (1.0 - overlay);
		// Where the background channel is ~0 the result is forced to 0.
		if (bgCol.x < 0.00001)
			outputColor.x = 0.0;
		if (bgCol.y < 0.00001)
			outputColor.y = 0.0;
		if (bgCol.z < 0.00001)
			outputColor.z = 0.0;
	}
	break;
	case 7:// Linear Dodge: Target + Blend
		outputColor = overlay + bgCol;
		break;
	case 8: //overlay // (Target > 0.5) * (1 - (1-2*(Target-0.5)) * (1-Blend)) + (Target <= 0.5) * ((2*Target) * Blend)
	{
		// a selects channels with Target > 0.5, b the remaining channels.
		a = float3((bgCol.x > 0.5 ? 1.0 : 0.0), (bgCol.y > 0.5 ? 1.0 : 0.0), (bgCol.z > 0.5 ? 1.0 : 0.0));
		b = float3((bgCol.x <= 0.5 ? 1.0 : 0.0), (bgCol.y <= 0.5 ? 1.0 : 0.0), (bgCol.z <= 0.5 ? 1.0 : 0.0));
		outputColor.xyz = a * (1.0 - (1.0 - 2.0*(bgCol.xyz - 0.5)) * (1.0 - overlay.xyz)) + b * ((2.0*bgCol.xyz) * overlay.xyz);
	}
	break;
	case 9: //Soft Light
	{
		// From here through mode 14 the masks test the overlay:
		// a selects channels with Blend > 0.5, b the remaining channels.
		a = float3((overlay.x > 0.5 ? 1.0 : 0.0), (overlay.y > 0.5 ? 1.0 : 0.0), (overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3((overlay.x <= 0.5 ? 1.0 : 0.0), (overlay.y <= 0.5 ? 1.0 : 0.0), (overlay.z <= 0.5 ? 1.0 : 0.0));
		outputColor.xyz = a * (2.0*bgCol.xyz*(1.0 - overlay.xyz) + sqrt(bgCol.xyz)*(2.0*overlay.xyz - 1.0)) + b * (2.0*bgCol.xyz*overlay.xyz + bgCol.xyz*bgCol.xyz*(1.0 - 2.0*overlay.xyz));
	}
	break;
	case 10://Hard Light //(Blend > 0.5) * (1 - (1-Target) * (1-2*(Blend-0.5))) + (Blend <= 0.5) * (Target * (2*Blend))
	{
		a = float3(float(overlay.x > 0.5 ? 1.0 : 0.0), float(overlay.y > 0.5 ? 1.0 : 0.0), float(overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3(float(overlay.x <= 0.5 ? 1.0 : 0.0), float(overlay.y <= 0.5 ? 1.0 : 0.0), float(overlay.z <= 0.5 ? 1.0 : 0.0));
		outputColor.xyz = a * (1.0 - (1.0 - bgCol.xyz) * (1.0 - 2.0*(overlay.xyz - 0.5))) + b * (bgCol.xyz * (2.0*overlay.xyz));
	}
	break;
	case 11://vivid light //// (Blend > 0.5) * (1 - (1-Target) / (2*(Blend-0.5))) + (Blend <= 0.5) * (Target / (1-2*Blend))
	{
		a = float3(float(overlay.x > 0.5 ? 1.0 : 0.0), float(overlay.y > 0.5 ? 1.0 : 0.0), float(overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3(float(overlay.x <= 0.5 ? 1.0 : 0.0), float(overlay.y <= 0.5 ? 1.0 : 0.0), float(overlay.z <= 0.5 ? 1.0 : 0.0));
		// As written: color burn with 2*Blend for the b channels,
		// color dodge with 2*(Blend-0.5) for the a channels.
		outputColor.xyz = b * colorBurn(bgCol, (2.0*overlay)).xyz + a * colorDodge(bgCol, (2.0*(overlay - 0.5))).xyz;
	}
	break;
	case 12:// Linear Light//  (Blend > 0.5) * (Target + 2*(Blend-0.5)) + (Blend <= 0.5) * (Target + 2*Blend - 1)
	{
		a = float3(float(overlay.x > 0.5 ? 1.0 : 0.0), float(overlay.y > 0.5 ? 1.0 : 0.0), float(overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3(float(overlay.x <= 0.5 ? 1.0 : 0.0), float(overlay.y <= 0.5 ? 1.0 : 0.0), float(overlay.z <= 0.5 ? 1.0 : 0.0));
		outputColor.xyz = a * (bgCol.xyz + 2.0*(overlay.xyz - 0.5)) + b * (bgCol.xyz + 2.0*overlay.xyz - 1.0);
	}
	break;
	case 13: //PIN Light// (Blend > 0.5) * (max(Target,2*(Blend-0.5))) + (Blend <= 0.5) * (min(Target,2*Blend)))
	{
		a = float3(float(overlay.x > 0.5 ? 1.0 : 0.0), float(overlay.y > 0.5 ? 1.0 : 0.0), float(overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3(float(overlay.x <= 0.5 ? 1.0 : 0.0), float(overlay.y <= 0.5 ? 1.0 : 0.0), float(overlay.z <= 0.5 ? 1.0 : 0.0));
		outputColor.xyz = a * (max(bgCol.xyz, 2.0*(overlay.xyz - 0.5))) + b * (min(bgCol.xyz, 2.0*overlay.xyz));
	}
	break;
	case 14: // hardmix  (VividLight(A,B) < 128) ? 0 : 255
	{
		a = float3(float(overlay.x > 0.5 ? 1.0 : 0.0), float(overlay.y > 0.5 ? 1.0 : 0.0), float(overlay.z > 0.5 ? 1.0 : 0.0));
		b = float3(float(overlay.x <= 0.5 ? 1.0 : 0.0), float(overlay.y <= 0.5 ? 1.0 : 0.0), float(overlay.z <= 0.5 ? 1.0 : 0.0));
		// Vivid-light style value built from the hard-mix burn/dodge variants,
		// then thresholded per channel at 0.5 to either 0 or 1.
		outputColor.xyz = b * colorBurnForHardMix(bgCol, (2.0*overlay)).xyz + a * colorDodgeForHardMix(bgCol, (2.0*(overlay - 0.5))).xyz;
		outputColor.xyz = float3(float(outputColor.x >= 0.5 ? 1.0 : 0.0), float(outputColor.y >= 0.5 ? 1.0 : 0.0), float(outputColor.z >= 0.5 ? 1.0 : 0.0));
		//outputColor.xyz = float3( float(overlay.x + bgCol.x >= 1.0?1.0:0.0), float(overlay.y + bgCol.y >= 1.0?1.0:0.0),float(overlay.z + bgCol.z >= 1.0?1.0:0.0));
	}
	break;
	case 15:// Difference: |Blend - Target|
		outputColor = abs(overlay - bgCol);
		break;
	case 16://exclusion // 0.5 - 2*(Target-0.5)*(Blend-0.5)
		outputColor = 0.5 - 2.0*(overlay - 0.5)*(bgCol - 0.5);
		break;
	case 17://Lighten // max(Target,Blend)
		outputColor = max(overlay, bgCol);
		break;
	case 19: // hollow in: background scaled by the overlay alpha
		outputColor = bgCol * overlay.w;
		outputColor = clamp(outputColor, float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0));
		break;
	case 20: // hollow out: background scaled by (1 - overlay alpha)
		outputColor = bgCol;
		// A (nearly) fully transparent background takes the overlay instead.
		if (bgCol.w < 0.000001)
			outputColor = overlay;
		else
			outputColor = bgCol * (1.0 - overlay.w);
		outputColor = clamp(outputColor, float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0));
		break;
	case 21: // backGround hollow in: overlay scaled by the background alpha
		outputColor = overlay * bgCol.w;
		outputColor = clamp(outputColor, float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0));
		break;
	case 22: // replace: overlay wins wherever tempMatt * exeMatt is non-negligible; returns directly
		if (tempMatt * exeMatt > 0.0001)
			return float4(overlay.xyz, overlay.w);
		else
			return bgCol;
	case 23:// add: Target + Blend (same expression as mode 7)
		outputColor = bgCol + overlay;
		break;
	default:
		// Unlisted modes: the background color is premultiplied by its alpha
		// and the overlay is composited over it. Returns directly; the final
		// xyz clamp of the shared tail is not applied here.
		bgCol = float4(bgCol.xyz*bgCol.w, bgCol.w);
		outputColor = overlay;

		if (ovlAlphaPreMul == 0)
		{
			tempOpacity = opacity * matt * exeMatt;
		}
		else {
			tempOpacity = opacity * tempMatt * exeMatt;
		}
		// invTemOpacity is not recomputed: it still holds
		// 1 - opacity * matt * exeMatt from the top of the function.
		outputColor = clamp(outputColor, float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0));
		outputColor.w = overlay.w + (1.0 - overlay.w)* bgCol.w;
		outputColor.xyz = outputColor.xyz*tempOpacity + invTemOpacity * bgCol.xyz;
		//outputColor.xyz = clamp( outputColor.xyz / outputColor.w, float3(0.0), float3(1.0) );

		return outputColor;
	}

	// Shared tail for every mode that left the switch with `break`.
	outputColor = clamp(outputColor, float4(0.0, 0.0, 0.0, 0.0), float4(1.0, 1.0, 1.0, 1.0));
	// Alpha is always recomputed here, overwriting whatever the mode produced.
	outputColor.w = overlay.w + (1.0 - overlay.w)* bgCol.w;
	// Mix the blended color with the untouched background by the blend weight.
	outputColor.xyz = outputColor.xyz*tempOpacity + invTemOpacity * bgCol.xyz;
	outputColor.xyz = clamp(outputColor.xyz, float3(0.0, 0.0, 0.0), float3(1.0, 1.0, 1.0));

	return outputColor;

}

// Tolerance used by isZero.
#define EQN_EPS 1e-9f

// Returns true when x lies strictly inside (-EQN_EPS, EQN_EPS).
static bool isZero(float x) {
	const bool aboveLowerBound = x > -EQN_EPS;
	const bool belowUpperBound = x < EQN_EPS;
	return aboveLowerBound && belowUpperBound;
}

// Compute kernel: composites `overlay` onto `background` inside a rectangular
// region of interest (ROI) and writes the result to `out`. Pixels outside the
// ROI receive the (resampled) background unchanged.
//
// overlay        [[texture(0)]] : overlay image, sampled with size inW x inH.
// background     [[texture(1)]] : background image, sampled with size outW x outH.
// out            [[texture(2)]] : destination texture.
// inW, inH       [[buffer(0,1)]] : overlay width / height in pixels.
// outW, outH     [[buffer(2,3)]] : output (and background) width / height in pixels.
// blend_x, blend_y [[buffer(4,5)]] : ROI origin as a fraction of the output
//                                    width / height.
// blendMode      [[buffer(6)]]  : mode forwarded to blending().
// kRender_Alpha  [[buffer(7)]]  : opacity forwarded to blending().
// ovlAlphaPreMul [[buffer(8)]]  : flag forwarded to blending().
// gid                           : position of this thread in the grid; one
//                                 thread handles one output pixel.
kernel void blend(texture2d<float, access::read> overlay [[texture(0)]],
                  texture2d<float, access::read> background [[texture(1)]],
                  texture2d<float, access::write> out [[texture(2)]],
                  constant int *inW [[buffer(0)]],
                  constant int *inH [[buffer(1)]],
                  constant int *outW [[buffer(2)]],
                  constant int *outH [[buffer(3)]],
                  constant float *blend_x[[buffer(4)]],
                  constant float *blend_y[[buffer(5)]],
                  constant int *blendMode[[buffer(6)]],
                  constant float *kRender_Alpha[[buffer(7)]],
                  constant int *ovlAlphaPreMul[[buffer(8)]],
                  uint2 gid [[thread_position_in_grid]])
{
    // The dispatched grid may be rounded up to a multiple of the threadgroup
    // size; threads that fall outside the destination must not write.
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    // Normalized output coordinate in [0, 1]. The divisor is kept >= 1 so a
    // one-pixel-wide (or -high) output does not produce 0/0.
    // NOTE(review): gid / (size - 1) is not the pixel-centre convention
    // ((gid + 0.5) / size) that SampleBilinear's half-texel offset assumes -
    // confirm whether the resulting slight rescale is intended.
    float u = float(gid.x) / float(max(*outW - 1, 1));
    float v = float(gid.y) / float(max(*outH - 1, 1));

    float ow = float(*inW);
    float oh = float(*inH);
    float bw = float(*outW);
    float bh = float(*outH);

    // ROI in output pixels: origin at (blend_x, blend_y) * output size, extent
    // equal to the overlay size, with the far edges clipped to the output.
    float roi_x0 = *blend_x  * bw;
    float roi_y0 = *blend_y  * bh;
    float roi_x1 = roi_x0 + ow;
    float roi_y1 = roi_y0 + oh;
    roi_x1 = clamp(roi_x1,0.0,bw);
    roi_y1 = clamp(roi_y1,0.0,bh);

    // Back to normalized coordinates for the inside/outside test.
    roi_x0 = roi_x0 / bw;
    roi_y0 = roi_y0 / bh;
    roi_x1 = roi_x1 / bw;
    roi_y1 = roi_y1 / bh;

    // 1.0 inside the ROI (edges inclusive), 0.0 outside.
    float matt_b = step(roi_x0,u) * step(u,roi_x1) * step(roi_y0,v) * step(v,roi_y1);

    float4 bgCol = SampleBilinear(background, u, v, *outW, *outH);
    if(isZero(matt_b)){
        // Outside the ROI: pass the background through.
        out.write(bgCol, gid);
        return;
    }

    // Position inside the ROI expressed in normalized overlay coordinates.
    float resizeCoord_x = (u - roi_x0) * bw / ow ;
    float resizeCoord_y = (v - roi_y0) * bh / oh ;

    float4 ovlCol = SampleBilinear(overlay, resizeCoord_x, resizeCoord_y, *inW, *inH);

    // The overlay alpha acts as `matt`, the ROI mask as `tempMatt`, and
    // `exeMatt` is fixed at 1.0.
    float4 FragColor = blending(bgCol, ovlCol, ovlCol.w, matt_b, 1.0, *blendMode, *kRender_Alpha, *ovlAlphaPreMul);
    out.write(FragColor, gid);
}
